// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Declaration and implementation of PacketThinnerCalculator.

#include <cmath>  // for ceil
#include <cstdint>
#include <memory>

#include "absl/log/absl_check.h"
#include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/options_util.h"

namespace mediapipe {

namespace {
// Number of timestamp units (microseconds) per second; used when deriving a
// new frame rate for the output header.
constexpr double kTimebaseUs = 1000000;
// Tag of the optional CalculatorOptions input side packet.
constexpr char kOptionsTag[] = "OPTIONS";
// Tag of the optional int64_t thinning-period input side packet.
constexpr char kPeriodTag[] = "PERIOD";
}  // namespace

// This calculator is used to thin an input stream of Packets.
// An example application would be to sample decoded frames of video
// at a coarser temporal resolution. Unless otherwise stated, all
// timestamps are in units of microseconds.
//
// Thinning can be accomplished in one of two ways:
// 1) asynchronous thinning (known below as async):
//    Algorithm does not rely on a master clock and is parameterized only
//    by a single option -- the period.  Once a packet is emitted, the
//    thinner will discard subsequent packets for the duration of the period
//    [Analogous to a refractory period during which packet emission is
//    suppressed.]
//    Packets arriving before start_time are discarded, as are packets
//    arriving at or after end_time.
// 2) synchronous thinning (known below as sync):
//    There are two variants of this algorithm, both parameterized by a
//    start_time and a period.  As in (1), packets arriving before start_time
//    or at/after end_time are discarded.  Otherwise, at most one packet is
//    emitted during a period, centered at timestamps generated by the
//    expression:
//      start_time + i * period  [where i is a non-negative integer]
//    During each period, the packet closest to the generated timestamp is
//    emitted (latest in the case of ties).  In the first variant
//    (sync_output_timestamps = true), the emitted packet is output at the
//    generated timestamp.  In the second variant, the packet is output at
//    its original timestamp.  Both variants emit exactly the same packets,
//    but at different timestamps.
//
// The thinning period can be provided in the calculator options or via an
// int64_t input side packet with the tag "PERIOD"; when that side packet is
// present it takes precedence over the option.
//
// Calculator options optionally provided via the "OPTIONS" input side packet
// are merged with this calculator's node options: singular fields in the side
// packet overwrite those defined in the node, and repeated fields are
// concatenated.
//
// Example config:
// node {
//   calculator: "PacketThinnerCalculator"
//   input_side_packet: "OPTIONS:calculator_options"
//   input_stream: "signal"
//   output_stream: "output"
//   options {
//     [mediapipe.PacketThinnerCalculatorOptions.ext] {
//       thinner_type: SYNC
//       period: 10
//       sync_output_timestamps: true
//       update_frame_rate: false
//     }
//   }
// }
class PacketThinnerCalculator : public CalculatorBase {
 public:
  PacketThinnerCalculator() {}
  ~PacketThinnerCalculator() override {}

  static absl::Status GetContract(CalculatorContract* cc) {
    if (cc->InputSidePackets().HasTag(kOptionsTag)) {
      cc->InputSidePackets().Tag(kOptionsTag).Set<CalculatorOptions>();
    }
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    if (cc->InputSidePackets().HasTag(kPeriodTag)) {
      cc->InputSidePackets().Tag(kPeriodTag).Set<int64_t>();
    }
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Close(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override {
    if (cc->InputTimestamp() < start_time_) {
      return absl::OkStatus();  // Drop packets before start_time_.
    } else if (cc->InputTimestamp() >= end_time_) {
      if (!cc->Outputs().Index(0).IsClosed()) {
        cc->Outputs()
            .Index(0)
            .Close();  // No more Packets will be output after end_time_.
      }
      return absl::OkStatus();
    } else {
      return thinner_type_ == PacketThinnerCalculatorOptions::ASYNC
                 ? AsyncThinnerProcess(cc)
                 : SyncThinnerProcess(cc);
    }
  }

 private:
  // Implementation of ASYNC and SYNC versions of thinner algorithm.
  absl::Status AsyncThinnerProcess(CalculatorContext* cc);
  absl::Status SyncThinnerProcess(CalculatorContext* cc);

  // Cached option.
  PacketThinnerCalculatorOptions::ThinnerType thinner_type_;

  // Given a Timestamp, finds the closest sync Timestamp
  // based on start_time_ and period_.  This can be earlier or
  // later than given Timestamp, but is guaranteed to be within
  // half a period_.
  Timestamp NearestSyncTimestamp(Timestamp now) const;

  // Cached option used by both async and sync thinners.
  TimestampDiff period_;  // Interval during which only one packet is emitted.
  Timestamp start_time_;  // Cached option - default Timestamp::Min()
  Timestamp end_time_;    // Cached option - default Timestamp::Max()

  // Only used by async thinner:
  Timestamp next_valid_timestamp_;  // Suppress packets until this timestamp.

  // Only used by sync thinner:
  Packet saved_packet_;          // Best packet not yet emitted.
  bool sync_output_timestamps_;  // Cached option.
};
REGISTER_CALCULATOR(PacketThinnerCalculator);

namespace {
// Magnitude of a timestamp difference: negates negative values and returns
// non-negative ones unchanged.
TimestampDiff abs(TimestampDiff t) {
  if (t < 0) {
    return -t;
  }
  return t;
}
}  // namespace

// Resolves the effective options (node options merged with the optional
// OPTIONS side packet) and validates them. Caches them in members and primes
// the output timestamp bound. If the input stream has a header, forwards it to
// the output; with update_frame_rate, the frame rate is first rewritten to
// reflect the thinning.
absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
  PacketThinnerCalculatorOptions options = mediapipe::tool::RetrieveOptions(
      cc->Options<PacketThinnerCalculatorOptions>(), cc->InputSidePackets(),
      kOptionsTag);

  thinner_type_ = options.thinner_type();
  // This check enables us to assume only two thinner types exist in Process()
  ABSL_CHECK(thinner_type_ == PacketThinnerCalculatorOptions::ASYNC ||
             thinner_type_ == PacketThinnerCalculatorOptions::SYNC)
      << "Unsupported thinner type.";

  if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
    // ASYNC thinner outputs packets with the same timestamp as their input so
    // it's safe to SetOffset(0). SYNC thinner manipulates timestamps of its
    // output so we don't do this for that case.
    cc->SetOffset(0);
  }

  // The PERIOD side packet, when present, takes precedence over the option.
  if (cc->InputSidePackets().HasTag(kPeriodTag)) {
    period_ =
        TimestampDiff(cc->InputSidePackets().Tag(kPeriodTag).Get<int64_t>());
  } else {
    period_ = TimestampDiff(options.period());
  }
  ABSL_CHECK_LT(TimestampDiff(0), period_)
      << "Specified period must be positive.";

  // Without an explicit start_time, ASYNC accepts everything from
  // Timestamp::Min(); SYNC anchors its sync grid at Timestamp(0).
  if (options.has_start_time()) {
    start_time_ = Timestamp(options.start_time());
  } else if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
    start_time_ = Timestamp::Min();
  } else {
    start_time_ = Timestamp(0);
  }

  end_time_ =
      options.has_end_time() ? Timestamp(options.end_time()) : Timestamp::Max();
  ABSL_CHECK_LT(start_time_, end_time_)
      << "Invalid PacketThinner: start_time must be earlier than end_time";

  sync_output_timestamps_ = options.sync_output_timestamps();

  next_valid_timestamp_ = start_time_;
  // Drop packets until this time.
  cc->Outputs().Index(0).SetNextTimestampBound(start_time_);

  if (!cc->Inputs().Index(0).Header().IsEmpty()) {
    if (options.update_frame_rate()) {
      const VideoHeader& video_header =
          cc->Inputs().Index(0).Header().Get<VideoHeader>();
      // Use the effective period, which may come from the PERIOD side
      // packet, rather than options.period(). Otherwise the header would not
      // describe the thinning actually applied.
      const int64_t period_us = period_.Value();
      double new_frame_rate;
      if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
        // Assuming a constant input frame rate, ASYNC keeps one frame out of
        // every ceil(frames per period).
        new_frame_rate =
            video_header.frame_rate /
            std::ceil(video_header.frame_rate * period_us / kTimebaseUs);
      } else {
        // SYNC emits at most one packet per period.
        const double sampling_rate = kTimebaseUs / period_us;
        new_frame_rate = video_header.frame_rate < sampling_rate
                             ? video_header.frame_rate
                             : sampling_rate;
      }
      auto header = std::make_unique<VideoHeader>();
      header->format = video_header.format;
      header->width = video_header.width;
      header->height = video_header.height;
      header->duration = video_header.duration;
      header->frame_rate = new_frame_rate;
      cc->Outputs().Index(0).SetHeader(Adopt(header.release()));
    } else {
      cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header());
    }
  }

  return absl::OkStatus();
}

// Emits the packet still held back by the SYNC thinner, if any, before the
// calculator shuts down.
absl::Status PacketThinnerCalculator::Close(CalculatorContext* cc) {
  if (saved_packet_.IsEmpty()) {
    return absl::OkStatus();
  }
  // Only sync thinner should have saved packets.
  ABSL_CHECK_EQ(PacketThinnerCalculatorOptions::SYNC, thinner_type_);
  // Process() closes the output stream when a packet at or after end_time_
  // arrives; do not add a packet to a stream that is already closed.
  if (cc->Outputs().Index(0).IsClosed()) {
    return absl::OkStatus();
  }
  if (sync_output_timestamps_) {
    cc->Outputs().Index(0).AddPacket(
        saved_packet_.At(NearestSyncTimestamp(saved_packet_.Timestamp())));
  } else {
    cc->Outputs().Index(0).AddPacket(saved_packet_);
  }
  return absl::OkStatus();
}

// ASYNC thinning: emit a packet, then suppress every packet that arrives
// within period_ of it (the "refractory period").
absl::Status PacketThinnerCalculator::AsyncThinnerProcess(
    CalculatorContext* cc) {
  const Timestamp now = cc->InputTimestamp();
  // Still inside the refractory period of the last emitted packet: drop.
  if (now < next_valid_timestamp_) {
    return absl::OkStatus();
  }
  auto& output = cc->Outputs().Index(0);
  output.AddPacket(cc->Inputs().Index(0).Value());
  next_valid_timestamp_ = now + period_;
  // Nothing before the end of the new refractory period will be emitted.
  output.SetNextTimestampBound(next_valid_timestamp_);
  return absl::OkStatus();
}

// SYNC thinning: per sync interval (centered at start_time_ + i * period_),
// hold the packet closest to the interval center, preferring the later packet
// on ties. Emit it once a packet from a later interval arrives.
absl::Status PacketThinnerCalculator::SyncThinnerProcess(
    CalculatorContext* cc) {
  auto& output = cc->Outputs().Index(0);
  const Packet& incoming = cc->Inputs().Index(0).Value();
  const Timestamp now = cc->InputTimestamp();

  // Nothing pending: hold this packet and advance the bound to the timestamp
  // it would be emitted at.
  if (saved_packet_.IsEmpty()) {
    saved_packet_ = incoming;
    const Timestamp bound =
        sync_output_timestamps_ ? NearestSyncTimestamp(now) : now;
    output.SetNextTimestampBound(bound);
    return absl::OkStatus();
  }

  const Timestamp saved = saved_packet_.Timestamp();
  const Timestamp saved_sync = NearestSyncTimestamp(saved);
  const Timestamp now_sync = NearestSyncTimestamp(now);
  ABSL_CHECK_LE(saved_sync, now_sync);

  if (saved_sync != now_sync) {
    // The pending packet won its (earlier) interval: emit it, then start
    // tracking the new interval with the incoming packet.
    if (sync_output_timestamps_) {
      output.AddPacket(saved_packet_.At(saved_sync));
      output.SetNextTimestampBound(now_sync);
    } else {
      output.AddPacket(saved_packet_);
      output.SetNextTimestampBound(now);
    }
    saved_packet_ = incoming;
    return absl::OkStatus();
  }

  // Same interval: keep whichever packet is closer to the interval center,
  // breaking ties in favor of the fresher (incoming) packet.
  const bool incoming_is_as_central =
      abs(now - now_sync) <= abs(saved - saved_sync);
  if (incoming_is_as_central) {
    saved_packet_ = incoming;
  }
  return absl::OkStatus();
}

// Returns the grid point start_time_ + k * period_ (k an integer) closest to
// `now`. Exact half-period ties round to the later grid point. The result is
// always within period_ / 2 of `now`.
Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const {
  ABSL_CHECK_NE(start_time_, Timestamp::Unset())
      << "Method only valid for sync thinner calculator.";

  // Computation is done using int64 arithmetic.  No easy way to avoid
  // since Timestamps don't support div and multiply.
  const int64_t now64 = now.Value();
  const int64_t start64 = start_time_.Value();
  const int64_t period64 = period_.Value();
  // A zero period would divide by zero below.
  ABSL_CHECK_LT(0, period64);

  // Round now64 to its closest grid point (units of period64). C++ integer
  // division truncates toward zero, so adjust to floor division. This keeps
  // timestamps earlier than start64 rounding to the correct grid point.
  const int64_t shifted = now64 - start64 + period64 / 2;
  int64_t intervals = shifted / period64;
  if (shifted % period64 < 0) {
    --intervals;
  }
  const int64_t sync64 = start64 + intervals * period64;

  const int64_t distance = now64 >= sync64 ? now64 - sync64 : sync64 - now64;
  ABSL_CHECK_LE(distance, period64 / 2)
      << "start64: " << start64 << "; now64: " << now64
      << "; sync64: " << sync64;

  return Timestamp(sync64);
}

}  // namespace mediapipe
